# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""tvm.contrib.msc.core.tools.base_tool"""

import os
import copy
import logging
from itertools import product
from typing import List, Iterable, Any, Tuple, Dict, Union
import numpy as np

import tvm
from tvm.contrib.msc.core.ir import MSCGraph, WeightGraph, MSCJoint, WeightJoint, MSCTensor
from tvm.contrib.msc.core.utils.namespace import MSCFramework
from tvm.contrib.msc.core import utils as msc_utils
from tvm.contrib.msc.core import _ffi_api


class ToolType(object):
    """Enumeration of the msc tool types.

    ``ALL`` collects the tool types reported by :meth:`all_types`;
    ``BASE`` and ``WEIGHT`` are intentionally not part of it.
    """

    # generic / helper types
    BASE = "base"
    WEIGHT = "weight"
    # concrete tool types
    PRUNER = "pruner"
    QUANTIZER = "quantizer"
    DISTILLER = "distiller"
    TRACKER = "tracker"
    ALL = [PRUNER, QUANTIZER, DISTILLER, TRACKER]

    @classmethod
    def all_types(cls) -> List[str]:
        """Return the ``ALL`` list of tool types (the list object itself)."""

        tool_types = cls.ALL
        return tool_types


class ToolScope(object):
    """Enumeration of the msc tool scopes (teacher / student)."""

    # scope of the reference (teacher) model
    TEACHER = "teacher"
    # scope of the model being optimized (student)
    STUDENT = "student"


class ToolExecutor(object):
    """Executor that binds a method with a fixed config to process tensors.

    Parameters
    ----------
    name: str
        The name.
    method: callable
        The method to execute.
    config: dict
        The config, passed as keyword arguments on every execution.
    """

    def __init__(self, name: str, method: callable, config: dict = None):
        self._name = name
        self._method = method
        self._config = config or {}

    def __str__(self):
        return "{}({})".format(self._name, self._config)

    def execute(self, *args, **kwargs) -> Any:
        """Execute the method

        Parameters
        ----------
        args: list<Any>
            The arguments for run method.
        kwargs: dict<Any>
            The key word arguments for run method. Entries of the executor
            config override keys with the same name.

        Returns
        -------
        plan or tensor:
           The plan generated by method or processed tensor.
        """

        # The config is merged last, so it wins over caller-provided kwargs.
        kwargs.update(self._config)
        return self._method(*args, **kwargs)

    def copy(self, name: str = None, method: callable = None, config: dict = None):
        """Copy the executor

        Parameters
        ----------
        name: str
            The name for new executor, defaults to the current name.
        method: callable
            The method for new executor, defaults to the current method.
        config: dict
            Config entries for the new executor. They override the current
            config; missing keys are filled from the current config. The
            given dict is not modified.

        Returns
        -------
        new_executor: ToolExecutor
           The copied executor
        """

        # Work on a fresh dict so the caller's ``config`` is never mutated.
        new_config = dict(config or {})
        for key, value in self._config.items():
            new_config.setdefault(key, value)
        return ToolExecutor(name or self._name, method or self._method, new_config)

    @property
    def method_def(self):
        """The method definition, as ``{"method_name": name, **config}``."""
        return {"method_name": self._name, **self._config}

    @property
    def name(self):
        """The name of the executor."""
        return self._name

    @property
    def config(self):
        """The config passed to the method on execution."""
        return self._config


class ToolStrategy(object):
    """Strategy for process tensor, holding one executor per stage.

    Parameters
    ----------
    name: str
        The name.
    tensor_type: str
        The tensor type.
    stage: str
        The init stage.
    """

    def __init__(self, name: str, tensor_type: str, stage: str = "default"):
        self._name = name
        self._tensor_type = tensor_type
        self._stage = stage
        self._executors = {}

    def __str__(self):
        return "{}({} @ {}) ".format(self._name, self._tensor_type, self._stage) + "; ".join(
            ["{}:{}".format(k, v) for k, v in self._executors.items()]
        )

    def inspect(self) -> dict:
        """Get inspect of strategy

        Returns
        -------
        inspect: dict
           The string description of the executor of each stage.
        """

        return {s: str(e) for s, e in self._executors.items()}

    def __call__(self, *args, **kwargs) -> Any:
        return self.apply(*args, **kwargs)

    def apply(self, *args, **kwargs) -> Any:
        """Apply the strategy with the executor of the current stage

        Parameters
        ----------
        args: list<Any>
            The arguments for run method.
        kwargs: dict<Any>
            The key word arguments for run method.

        Returns
        -------
        plan or tensor:
           The plan generated by method or processed tensor.
        """

        return self.get_executor().execute(*args, **kwargs)

    def change_stage(self, stage: str):
        """Change the stage of strategy"""

        self._stage = stage

    def add_executor(self, stage: str, executor: "ToolExecutor"):
        """Add (or replace) the executor of a stage

        Parameters
        ----------
        stage: str
            The mark of the executor.
        executor: ToolExecutor
            The executor to process tensor.
        """

        self._executors[stage] = executor
        # Adopt the first added stage when no init stage was given.
        if not self._stage:
            self._stage = stage

    def get_executor(self, stage: str = None) -> "ToolExecutor":
        """Get executor of a stage, falling back to the "default" executor

        Parameters
        ----------
        stage: str
            The mark of the executor, defaults to the current stage.

        Returns
        -------
        executor: ToolExecutor
           The executor of the stage.

        Raises
        ------
        KeyError
            If neither the stage nor "default" has an executor.
        """

        stage = stage or self._stage
        if stage in self._executors:
            return self._executors[stage]
        if "default" not in self._executors:
            raise KeyError(
                "Can not find executor for stage {} in strategy {}".format(stage, self._name)
            )
        return self._executors["default"]

    def get_config(self) -> dict:
        """Get the config of current executor"""

        return self.get_executor().config

    def support_stage(self, stage: str) -> bool:
        """Check if the strategy support a stage

        Parameters
        ----------
        stage: str
            The mark of the executor

        Returns
        -------
        support: bool
           Whether the stage (or the "default" fallback) has an executor.
        """

        return stage in self._executors or "default" in self._executors

    def copy(
        self,
        name: str = None,
        tensor_type: str = None,
        stage: str = None,
        configs: Dict[str, dict] = None,
    ):
        """Copy a strategy

        Parameters
        ----------
        name: str
            The name for new strategy
        tensor_type:
            The tensor type for new strategy
        stage: str
            The init stage for new strategy
        configs: dict<str,dict>
            The method config of new executors, keyed by stage. Not modified.

        Returns
        -------
        new_strategy: ToolStrategy
           The copied strategy
        """

        configs = configs or {}
        strategy = ToolStrategy(
            name or self._name, tensor_type or self._tensor_type, stage or self._stage
        )
        for st_name, executor in self._executors.items():
            # Hand over a private dict so the caller's ``configs`` stays intact.
            new_executor = executor.copy(config=dict(configs.get(st_name, {})))
            strategy.add_executor(st_name, new_executor)
        return strategy


class BaseTool(object):
    """Basic tool of MSC

    Parameters
    ----------
    tag: str
        The tag of tool.
    stage: str
        The stage of tool.
    plan_file: str
        The plan file path.
    strategys: list[dict]
        The strategys of the tool.
    training: bool
        Whether the tool is training.
    cache_processed: bool
        Whether to cache processed tensor.
    options: dict
        The extra options for the tool
    debug_level: int
        The debug level.
    verbose_step: int
        The verbose interval step.
    logger: logging.Logger
        The logger
    """

    def __init__(
        self,
        tag: str,
        stage: str,
        plan_file: str,
        strategys: List[dict],
        training: bool = False,
        cache_processed: bool = True,
        options: dict = None,
        debug_level: int = 0,
        verbose_step: int = 50,
        logger: logging.Logger = None,
    ):
        self._tag = tag
        self._stage = stage
        self._plan_file = plan_file
        # An existing plan file means the tool applies that plan; otherwise it
        # makes a new one. A falsy plan_file (e.g. None) is treated as "no plan",
        # since os.path.isfile(None) would raise TypeError.
        if plan_file and os.path.isfile(plan_file):
            self._plan = msc_utils.load_dict(plan_file)
        else:
            self._plan = {}
        # Keep a private copy of the raw strategy configs; they are parsed into
        # ToolStrategy objects in reset(), once graphs are available.
        self._meta_strategys, self._strategys = msc_utils.copy_dict(strategys), {}
        self._training = training
        self._cache_processed = cache_processed
        self._options = options or {}
        self._debug_level = debug_level
        self._verbose_step = verbose_step
        self._logger = logger or msc_utils.get_global_logger()
        title = self.tool_mark("APPLY_PLAN" if self._plan else "MAKE_PLAN")
        self._logger.info(msc_utils.msg_block(title, self.setup()))

    def __str__(self):
        # Summarize forward count and the number of graphs / weights held.
        graph_num, weight_num = len(self._graphs), len(self._weights)
        summary = "forward[{}] {} graphs, {} weights".format(
            self._forward_cnt, graph_num, weight_num
        )
        return self.tool_mark(summary)

    def setup(self) -> dict:
        """Initialize the runtime state of the tool

        Returns
        -------
        info: dict
            Summary of the tool settings, used for the setup log.
        """

        self._tensor_cache = {}
        self._enabled = True
        self._graphs = []
        self._weights = {}
        self._graph_id = 0
        self._forward_cnt = 0
        self._processed_tensor = {}
        # Show the full plan only at debug level >= 2, otherwise just its path.
        if self._plan and self._debug_level >= 2:
            plan_info = self._plan
        else:
            plan_info = self._plan_file
        info = {}
        info["style"] = self.tool_style()
        info["cache_processed"] = self._cache_processed
        info["options"] = self._options
        info["debug_step({})".format(self._debug_level)] = self._verbose_step
        info["plan({})".format(len(self._plan))] = plan_info
        return info

    def reset(
        self,
        graphs: List[MSCGraph],
        weights: Dict[str, tvm.runtime.Tensor],
        cache_dir: msc_utils.MSCDirectory = None,
    ) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]:
        """Reset the tool with graphs and weights

        Loads cached tool info from ``cache_info.json`` in ``cache_dir`` when
        present, then resets graphs/weights and re-parses the strategys.

        Parameters
        ----------
        graphs: list<MSCgraph>
            The msc graphs.
        weights: dict<str, tvm.runtime.tensor>
            The weights.
        cache_dir: MSCDirectory
            cache path for save/load info.

        Returns
        -------
        graphs: list<MSCgraph>
            The msc graphs.
        weights: dict<str, tvm.runtime.tensor>
            The weights.
        """

        self._forward_cnt = 0
        self._tensor_cache = {}
        cache_info = {}
        if cache_dir:
            info_file = cache_dir.relpath("cache_info.json")
            if os.path.isfile(info_file):
                cache_info = msc_utils.load_dict(info_file)
        tool_type = self.tool_type()
        if tool_type in cache_info:
            self.load_cache(cache_dir, cache_info[tool_type])
        self._graphs, self._weights = self._reset(graphs, weights)
        self._strategys = self._parse_strategys(self._meta_strategys)
        if self._strategys:
            strategys_info = {
                mark: strategy.inspect() for mark, strategy in self._strategys.items()
            }
            title = self.tool_mark("STRATEGYS({})".format(len(self._strategys)))
            self._logger.info(msc_utils.msg_block(title, strategys_info, width=0))
        return self._graphs, self._weights

    def _reset(
        self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor]
    ) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]:
        """Hook for subclasses to adjust graphs and weights on reset

        The base implementation returns them unchanged.

        Parameters
        ----------
        graphs: list<MSCgraph>
            The msc graphs.
        weights: dict<str, tvm.runtime.tensor>
            The weights.

        Returns
        -------
        graphs: list<MSCgraph>
            The msc graphs.
        weights: dict<str, tvm.runtime.tensor>
            The weights.
        """

        result = (graphs, weights)
        return result

    def _parse_strategys(self, strategy_list: List[dict]) -> Dict[str, ToolStrategy]:
        """Parse the raw strategy configs into ToolStrategy objects

        Each strategy dict must contain ``methods``: a mapping from tensor type
        to a method definition, given either as a method name string or as a
        dict with ``method_name`` plus extra kwargs for the method. The
        strategy is bound to marks chosen by the first selector key found, in
        this order: ``marks``, ``tensor_names``, ``tensor_ids``, ``op_types``,
        ``op_names``; without any selector the mark is ``default.<t_type>``.
        Names/ids/types not present in the graphs are dropped. An optional
        ``stages`` list (default ``["default"]``) selects the stages the
        executor is registered for; a later strategy with the same mark and
        stage replaces the earlier executor.

        Parameters
        ----------
        strategy_list: list<dict>
            The given strategys.

        Returns
        -------
        strategys: dict<str, ToolStrategy>
            The parsed strategys, keyed by mark.

        Raises
        ------
        AssertionError
            On malformed input, missing graphs, or an unresolvable method.
        TypeError
            If a method definition is neither str nor dict.
        """

        assert isinstance(strategy_list, list) and all(
            isinstance(s, dict) for s in strategy_list
        ), "ToolStrategy should be given as list of dict"
        assert self._graphs, "graphs are needed to parse strategys"
        # Lookup sets used to filter selectors against what exists in the graphs.
        all_tensor_names = set(t.name for t in self.get_tensors())
        all_tensor_ids = set(self.get_tensor_ids())
        all_op_types = set(n.optype for n in self.get_nodes())
        all_op_names = set(n.name for n in self.get_nodes())
        strategys = {}

        def _get_method(method_name):
            # "cls.method" selects a registered method class; a bare name uses
            # the "default" class. Resolution order: the registered class for
            # this framework, then the MSC framework class, then a registered
            # global func.
            if "." in method_name:
                method_cls_name, method_name = method_name.split(".")
            else:
                method_cls_name = "default"
            method_cls = msc_utils.get_registered_tool_method(
                self.framework(), self.tool_type(), method_cls_name
            )
            if hasattr(method_cls, method_name):
                return getattr(method_cls, method_name)
            default_cls = msc_utils.get_registered_tool_method(
                MSCFramework.MSC, self.tool_type(), method_cls_name
            )
            if hasattr(default_cls, method_name):
                return getattr(default_cls, method_name)
            method = msc_utils.get_registered_func(method_name)
            assert method, "Can not find method with " + str(method_name)
            return method

        for strategy in strategy_list:
            meta_strategy = msc_utils.copy_dict(strategy)
            for t_type, method_def in meta_strategy["methods"].items():
                # Normalize the method definition into (name, kwargs).
                if isinstance(method_def, str):
                    method_name, method_kwargs = method_def, {}
                elif isinstance(method_def, dict):
                    assert "method_name" in method_def, "Can not find method_name"
                    method_name = method_def["method_name"]
                    method_kwargs = {k: v for k, v in method_def.items() if k != "method_name"}
                else:
                    raise TypeError(
                        "Only support string and dict as method define, get " + str(method_def)
                    )
                method = _get_method(method_name)
                # Resolve the marks this strategy applies to (first selector wins).
                if "marks" in strategy:
                    assert t_type == "mark", "mark strategy only support mark method, get " + str(
                        meta_strategy
                    )
                    marks = strategy["marks"]
                elif "tensor_names" in strategy:
                    assert (
                        t_type == "tensor"
                    ), "tensor strategy only support tensor method, get " + str(meta_strategy)
                    marks = [t for t in strategy["tensor_names"] if t in all_tensor_names]
                elif "tensor_ids" in strategy:
                    assert (
                        t_type == "tensor"
                    ), "tensor strategy only support tensor method, get " + str(meta_strategy)
                    marks = [t for t in strategy["tensor_ids"] if t in all_tensor_ids]
                elif "op_types" in strategy:
                    op_types = [t for t in strategy["op_types"] if t in all_op_types]
                    marks = ["{}.{}".format(t, t_type) for t in op_types]
                elif "op_names" in strategy:
                    op_names = [t for t in strategy["op_names"] if t in all_op_names]
                    marks = ["{}.{}".format(t, t_type) for t in op_names]
                else:
                    marks = ["default." + str(t_type)]
                # Register one executor per (mark, stage); kwargs are deep-copied
                # so executors never share mutable config objects.
                for mark, stage in product(marks, strategy.get("stages", ["default"])):
                    if mark not in strategys:
                        strategys[mark] = ToolStrategy(mark, t_type, self._stage)
                    strategys[mark].add_executor(
                        stage, ToolExecutor(method_name, method, copy.deepcopy(method_kwargs))
                    )
        return strategys

    def change_strategys(self, strategy_list: List[dict]):
        """Change the raw strategys

        The new strategys are parsed on the next ``reset``. A private copy is
        kept (as in ``__init__``), so later changes to ``strategy_list`` by the
        caller do not leak into the tool.

        Parameters
        ----------
        strategy_list: list<dict>
            The given strategys.
        """

        self._meta_strategys = msc_utils.copy_dict(strategy_list)

    def change_stage(self, stage: str):
        """Switch the tool and every parsed strategy to ``stage``"""

        self._stage = stage
        for mark in self._strategys:
            self._strategys[mark].change_stage(stage)

    def change_logger(self, logger: logging.Logger):
        """Replace the logger used by the tool"""

        setattr(self, "_logger", logger)

    def destory(self):
        """Destroy the tool by dropping its graphs and weights"""

        self._graphs = []
        self._weights = {}

    def export_config(self, config: dict, folder: msc_utils.MSCDirectory) -> dict:
        """Export the config for tool

        Copies the plan file (resolved against the config dir) into the
        ``tools`` sub folder of ``folder`` when it exists.

        Parameters
        ----------
        config: dict
            The source config, must contain ``plan_file``.
        folder: MSCDirectory
            The export folder.

        Returns
        -------
        config: dict
            ``{"plan_file": <copied path>}``, or an empty dict without plan.
        """

        plan_path = msc_utils.to_abs_path(config["plan_file"], msc_utils.get_config_dir())
        if not os.path.isfile(plan_path):
            return {}
        tools_dir = folder.create_dir("tools")
        return {"plan_file": tools_dir.copy(plan_path)}

    def load_cache(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict):
        """Load tool state from cache; the base implementation does nothing

        Parameters
        ----------
        cache_dir: MSCDirectory
            cache path for save/load info
        cache_info: dict
            The cache_info of this tool type.
        """

        del cache_dir, cache_info
        return None

    def save_cache(self, cache_dir: msc_utils.MSCDirectory) -> dict:
        """Save tool state to cache; the base implementation saves nothing

        Parameters
        ----------
        cache_dir: MSCDirectory
            cache path for save/load info

        Returns
        -------
        cache_info: dict
            The cache_info, always empty here.
        """

        cache_info = {}
        return cache_info

    def execute_before_build(self, *args, **kwargs):
        """Hook called before model build

        When enabled, resolves the graph id (popping ``graph_id`` /
        ``graph_name`` from kwargs), clears processed tensors and calls
        ``_execute_before_build``.

        Parameters
        ----------
        args: list<Any>
            The arguments for model build.
        kwargs: dict<Any>
            The key word arguments for model build.
        """

        if not self._enabled:
            return
        self._graph_id = self._infer_graph_id(kwargs)
        self._processed_tensor = {}
        if self.on_debug(3, in_forward=False):
            start_msg = self.msg_mark("Start Build", in_forward=False)
            self._logger.debug(start_msg)
        self._execute_before_build(*args, **kwargs)

    def _execute_before_build(self, *args, **kwargs):
        """Subclass hook run before model build; does nothing by default

        Parameters
        ----------
        args: list<Any>
            The arguments for model build.
        kwargs: dict<Any>
            The key word arguments for model build.
        """

        del args, kwargs
        return None

    def execute_after_build(self, output: Any) -> Any:
        """Hook called after model build

        Parameters
        ----------
        output: Any
            The output reference of the model.

        Returns
        -------
        output: Any
           The output from ``_execute_after_build`` when enabled, else the
           given output unchanged.
        """

        if not self._enabled:
            return output
        output = self._execute_after_build(output)
        if self.on_debug(3, in_forward=False):
            end_msg = self.msg_mark("End Build", in_forward=False)
            self._logger.debug(end_msg)
        return output

    def _execute_after_build(self, output: Any) -> Any:
        """Subclass hook run after model build; identity by default

        Parameters
        ----------
        output: Any
            The output reference of the model.

        Returns
        -------
        output: Any
           The (possibly modified) output reference.
        """

        result = output
        return result

    def execute_before_forward(self, *args, **kwargs):
        """Hook called before model forward

        When enabled, resolves the graph id (popping ``graph_id`` /
        ``graph_name`` from kwargs), clears processed tensors and calls
        ``_execute_before_forward``.

        Parameters
        ----------
        args: list<Any>
            The arguments for model forward.
        kwargs: dict<Any>
            The key word arguments for model forward.
        """

        if not self._enabled:
            return
        self._graph_id = self._infer_graph_id(kwargs)
        self._processed_tensor = {}
        if self.on_debug(3):
            start_msg = self.msg_mark("Start Forward")
            self._logger.debug(start_msg)
        self._execute_before_forward(*args, **kwargs)

    def _execute_before_forward(self, *args, **kwargs):
        """Subclass hook run before model forward; does nothing by default

        Parameters
        ----------
        args: list<Any>
            The arguments for model forward.
        kwargs: dict<Any>
            The key word arguments for model forward.
        """

        del args, kwargs
        return None

    def execute_after_forward(self, output: Any) -> Any:
        """Hook called after model forward

        When enabled, runs ``_execute_after_forward`` and then increments the
        forward counter.

        Parameters
        ----------
        output: Any
            The output reference of the model.

        Returns
        -------
        output: Any
           The modified output reference.
        """

        if not self._enabled:
            return output
        output = self._execute_after_forward(output)
        if self.on_debug(3):
            processed_num = len(self._processed_tensor)
            end_msg = "End Forward, process {} tensors".format(processed_num)
            self._logger.debug(self.msg_mark(end_msg))
        self._forward_cnt += 1
        return output

    def _execute_after_forward(self, output: Any) -> Any:
        """Subclass hook run after model forward; identity by default

        Parameters
        ----------
        output: Any
            The output reference of the model.

        Returns
        -------
        output: Any
           The (possibly modified) output reference.
        """

        result = output
        return result

    def process_tensor(self, tensor: Any, name: str, consumer: str, scope: str) -> Any:
        """Process tensor with the strategys matched for (name, consumer)

        Parameters
        ----------
        tensor: Any
            Tensor in framework
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        scope: str
            The scope mark teacher| student| null

        Returns
        -------
        tensor: Any
            The processed tensor, a cached result, or the input unchanged when
            the tool is disabled / the scope is unsupported / nothing applies.
        """

        if not self._enabled or not self._support_scope(scope):
            return tensor
        strategys = self._get_tensor_strategys(name, consumer)
        # The mark combines executor names (and scope) to key the processed cache.
        t_mark = ".".join([s.get_executor().name for s in strategys])
        if scope:
            t_mark += "." + scope
        cached_tensor = self._get_processed(name, consumer, t_mark)
        if cached_tensor is not None:
            self.debug_tensors(name, consumer, t_mark, {"cached": cached_tensor})
            return cached_tensor
        process = self._get_tensor_cache(name, consumer, "process")
        if process is None:
            process = self._check_tensor(name, consumer)
            self._save_tensor_cache(name, consumer, "process", process)
        if not process:
            return tensor
        # Dict-style tensors are processed on a copy so the input stays intact.
        source = msc_utils.copy_dict(tensor) if isinstance(tensor, dict) else tensor
        new_tensor = self._process_tensor(source, name, consumer, scope, strategys)
        self._save_processed(name, consumer, new_tensor, t_mark)
        # Build the debug payload only when it will be logged (debug_tensors'
        # default level is 3): the array diff is a full-size computation that
        # would otherwise run for every processed tensor.
        if self.on_debug(3):
            if msc_utils.is_array(tensor) and new_tensor is not tensor:
                tensors = {"org": tensor, "new": new_tensor, "dif": tensor - new_tensor}
                self.debug_tensors(name, consumer, t_mark, tensors)
            elif isinstance(tensor, dict) and len(tensor.get("processed", [])) != len(
                new_tensor.get("processed", [])
            ):
                tensors = {"org": tensor, "new": new_tensor}
                self.debug_tensors(name, consumer, t_mark, tensors)
        return new_tensor

    def _support_scope(self, scope: str) -> bool:
        """Check if the scope is supported

        An empty scope is always supported; the teacher scope is not.

        Parameters
        ----------
        scope: str
            The scope mark, should be null or ToolScope

        Returns
        -------
        valid: bool
            Whether to process the tensor.
        """

        return not scope or scope != ToolScope.TEACHER

    def _get_processed(self, name: str, consumer: str, strategy_mark: str) -> Any:
        """Get cached processed tensor

        The cache key is ``<name>.<strategy_mark>``; ``consumer`` is not part
        of it.

        Parameters
        ----------
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        strategy_mark: str
            The strategy mark.

        Returns
        -------
        processed_tensor
            The cached processed tensor, or None when caching is disabled or
            nothing is cached.
        """

        if not self._cache_processed:
            return None
        cache_key = name + "." + strategy_mark
        return self._processed_tensor.get(cache_key)

    def _save_processed(self, name: str, consumer: str, tensor: Any, strategy_mark: str):
        """Save cached processed tensor

        With caching disabled, only a placeholder keyed by the tensor id is
        recorded (so the processed count is still tracked).

        Parameters
        ----------
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        tensor: Any
            The processed tensor
        strategy_mark: str
            The strategy mark.
        """

        if not self._cache_processed:
            self._processed_tensor[self.to_tensor_id(name, consumer)] = None
            return
        cache_key = name + "." + strategy_mark
        self._processed_tensor[cache_key] = tensor

    def _check_tensor(self, name: str, consumer: str) -> bool:
        """Check if the tensor should be processed

        Parameters
        ----------
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.

        Returns
        -------
        valid: bool
            True when at least one strategy matches the tensor.
        """

        matched = self._get_tensor_strategys(name, consumer)
        return len(matched) > 0

    def _process_tensor(
        self, tensor: Any, name: str, consumer: str, scope: str, strategys: List[ToolStrategy]
    ) -> Any:
        """Subclass hook that processes a tensor; identity by default

        Parameters
        ----------
        tensor: Any
            Tensor in framework
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        scope: str
            The scope mark teacher| student| null.
        strategys: list<ToolStrategy>
            The strategys for the tensor.

        Returns
        -------
        tensor: Any
            The processed tensor.
        """

        processed = tensor
        return processed

    def create_tasks(self, **kwargs) -> List[dict]:
        """Create tasks for gym; the base tool creates none

        Parameters
        ----------
        kwargs: dict
           The kwargs for create tasks.

        Returns
        -------
        tasks: list<dict>
            The tasks.
        """

        tasks = []
        return tasks

    def config_generate(self, generate_config: Dict[str, Any]) -> Dict[str, Any]:
        """Update the generate configs; returned unchanged by default

        Parameters
        ----------
        generate_config: dict<str, Any>
            The generate_config.

        Returns
        -------
        generate_config: dict<str, Any>
            The updated generate_config.
        """

        updated = generate_config
        return updated

    def visualize(self, visual_dir: msc_utils.MSCDirectory):
        """Visualize MSCGraphs; the base implementation does nothing

        Parameters
        ----------
        visual_dir: MSCDirectory
            Visualize path for saving graph
        """

        del visual_dir
        return None

    def finalize(self) -> dict:
        """Return the plan of the tool"""

        plan = self._plan
        return plan

    def enable(self):
        """Turn the tool on"""

        setattr(self, "_enabled", True)

    def disable(self):
        """Turn the tool off"""

        setattr(self, "_enabled", False)

    def train(self):
        """Switch the tool to train mode"""

        setattr(self, "_training", True)

    def eval(self):
        """Switch the tool to eval mode"""

        setattr(self, "_training", False)

    def to_tensor_id(self, name: str, consumer: str) -> str:
        """Build the edge id ``<name>-c-<consumer>``

        Parameters
        ----------
        name: str
            The name of tensor.
        consumer: str
            The name of consumer.

        Returns
        -------
        tensor_id: str
           The unique name of edge.
        """

        return f"{name}-c-{consumer}"

    def from_tensor_id(self, tensor_id: str) -> List[str]:
        """Split an edge id built by ``to_tensor_id``

        Parameters
        ----------
        tensor_id: str
           The unique name of edge.

        Returns
        -------
        parts: list<str>
            The pieces split on ``-c-``. For a well-formed id this is
            ``[name, consumer]``; names that themselves contain ``-c-``
            produce more pieces.
        """

        return tensor_id.split("-c-")

    def is_weight(self, name: str) -> bool:
        """Check if the tensor is one of the tool weights

        Parameters
        ----------
        name: str
           The name of tensor.

        Returns
        -------
        is_weight: bool
            Whether the name is a key of the weights.
        """

        weights = self._weights
        return name in weights

    def on_debug(self, debug_level: int = 1, in_forward: bool = True) -> bool:
        """Check if debug info should be logged

        Inside forward, only every ``verbose_step``-th forward is logged.

        Parameters
        ----------
        debug_level: int
           The given debug_level.
        in_forward: bool
            Whether to check forward_cnt.

        Returns
        -------
        on_debug: bool
            Whether to log debug info.
        """

        off_step = in_forward and self._forward_cnt % self._verbose_step != 0
        return not off_step and self._debug_level >= debug_level

    def tool_mark(self, msg: Any) -> str:
        """Prefix a message with ``TYPE[tag](framework @ stage)``

        Parameters
        ----------
        msg: str
            The message

        Returns
        -------
        msg: str
            The message with mark.
        """

        tool_type = self.tool_type().upper()
        return f"{tool_type}[{self._tag}]({self.framework()} @ {self._stage}) {msg}"

    def msg_mark(self, msg: Any, in_forward: bool = True) -> str:
        """Prefix a message with tool, graph id and (optionally) forward count

        Parameters
        ----------
        msg:
            The message
        in_forward: bool
            Whether to add forward mark.

        Returns
        -------
        msg: str
            The message with mark.
        """

        parts = [
            "{}({} @ {}) G[{}]".format(
                self.tool_type().upper(), self._tag, self._stage, self._graph_id
            )
        ]
        if in_forward:
            parts.append(".F[{}]".format(self._forward_cnt))
        return "".join(parts) + " " + str(msg)

    def debug_tensors(
        self, name: str, consumer: str, t_mark: str, tensors: Dict[str, Any], debug_level: int = 3
    ) -> None:
        """Log a description of the tensors when debugging is on

        Parameters
        ----------
        name: str
            The name of tensor.
        consumer: str
            The name of consumer.
        t_mark: str
            The mark of tensor.
        tensors: dict<str,array_like>
            The tensors to describe, keyed by label.
        debug_level: int
           The given debug_level.
        """

        if not self.on_debug(debug_level):
            return

        def _t_info(tensor):
            # Arrays get a statistical summary; dict-style tensors with a
            # "processed" list are summarized by the processed count.
            if msc_utils.is_array(tensor):
                return msc_utils.inspect_array(tensor)
            if isinstance(tensor, dict) and "processed" in tensor:
                return "{}({} processed)".format(self.find_tensor(name), len(tensor["processed"]))
            return str(tensor)

        msg = "{}-{}({})".format(name, consumer, t_mark)
        tensor_des = "\n  ".join(["{:6s}:{}".format(k, _t_info(v)) for k, v in tensors.items()])
        self._logger.debug("%s\n  %s", self.msg_mark(msg), tensor_des)

    def _infer_graph_id(self, kwargs: dict) -> int:
        """Infer graph id from kwargs

        Parameters
        ----------
        kwargs: dict
           The kwargs for execute.
        """

        if "graph_id" in kwargs:
            return kwargs.pop("graph_id")
        if "graph_name" in kwargs:
            name = kwargs.pop("graph_name")
            for idx, g in enumerate(self._graphs):
                if g.name == name:
                    return idx
        return 0

    def get_nodes(self) -> Iterable[MSCJoint]:
        """Iterate over the nodes of every graph held by the tool

        Returns
        -------
        nodes: generator<MSCJoint>
            The nodes, graph after graph, in each graph's ``get_nodes`` order.
        """

        for graph in self._graphs:
            yield from graph.get_nodes()

    def find_node(self, name: str) -> MSCJoint:
        """Find a node by name in the first graph that contains it

        Parameters
        ----------
        name: string
            The name of the node.

        Returns
        -------
        node: MSCJoint
            The found node.

        Raises
        ------
        Exception
            If no graph has a node with this name.
        """

        owner = next((graph for graph in self._graphs if graph.has_node(name)), None)
        if owner is None:
            raise Exception("Can not find node {} from {} graphs".format(name, len(self._graphs)))
        return owner.find_node(name)

    def get_tensors(self) -> Iterable[MSCTensor]:
        """Iterate over the tensors of every graph held by the tool

        Returns
        -------
        tensors: generator<MSCTensor>
            The tensors, graph after graph, in each graph's ``get_tensors`` order.
        """

        for msc_graph in self._graphs:
            yield from msc_graph.get_tensors()

    def get_tensor_ids(self) -> Iterable[str]:
        """Iterate over the ids of all consumed tensors in the graphs

        For every node of every graph, one id is produced per input tensor and
        then one per weight, each built with
        ``self.to_tensor_id(tensor_name, node_name)``.

        Returns
        -------
        tensor_ids: generator<str>
            The generator of tensor ids (not MSCTensor objects).
        """

        for graph in self._graphs:
            for node in graph.get_nodes():
                for tensor in node.get_inputs():
                    yield self.to_tensor_id(tensor.name, node.name)
                for weight in node.get_weights().values():
                    yield self.to_tensor_id(weight.name, node.name)

    def find_tensor(self, t_ref: Union[str, MSCTensor]) -> MSCTensor:
        """Find a tensor in the first graph that contains it

        Parameters
        ----------
        t_ref: string| MSCTensor
            The tensor name, or a tensor whose ``name`` is used for the lookup.

        Returns
        -------
        tensor: MSCTensor
            The found tensor.

        Raises
        ------
        Exception
            If no graph has a tensor with this name.
        """

        if isinstance(t_ref, MSCTensor):
            t_name = t_ref.name
        else:
            t_name = t_ref
        owner = next((graph for graph in self._graphs if graph.has_tensor(t_name)), None)
        if owner is None:
            raise Exception(
                "Can not find tensor {} from {} graphs".format(t_name, len(self._graphs))
            )
        return owner.find_tensor(t_name)

    def find_producer(self, t_ref: Union[str, MSCTensor]) -> MSCJoint:
        """Find the node producing a tensor, in the first graph that holds the tensor

        Parameters
        ----------
        t_ref: string| MSCTensor
            The tensor name, or a tensor whose ``name`` is used for the lookup.

        Returns
        -------
        node: MSCJoint
            The found producer.

        Raises
        ------
        Exception
            If no graph has a tensor with this name.
        """

        if isinstance(t_ref, MSCTensor):
            t_name = t_ref.name
        else:
            t_name = t_ref
        owner = next((graph for graph in self._graphs if graph.has_tensor(t_name)), None)
        if owner is None:
            raise Exception(
                "Can not find producer of {} from {} graphs".format(t_name, len(self._graphs))
            )
        return owner.find_producer(t_name)

    def find_consumers(self, t_ref: Union[str, MSCTensor]) -> List[MSCJoint]:
        """Find the nodes consuming a tensor, in the first graph that holds the tensor

        Parameters
        ----------
        t_ref: string| MSCTensor
            The tensor name, or a tensor whose ``name`` is used for the lookup.

        Returns
        -------
        nodes: list<MSCJoint>
            The found consumers.

        Raises
        ------
        Exception
            If no graph has a tensor with this name.
        """

        if isinstance(t_ref, MSCTensor):
            t_name = t_ref.name
        else:
            t_name = t_ref
        owner = next((graph for graph in self._graphs if graph.has_tensor(t_name)), None)
        if owner is None:
            raise Exception(
                "Can not find consumers of {} from {} graphs".format(t_name, len(self._graphs))
            )
        return owner.find_consumers(t_name)

    def get_data(self, name: str) -> np.ndarray:
        """Get a weight's data by name, converted with ``msc_utils.cast_array``

        Parameters
        ----------
        name: str
            The tensor name, looked up in the tool's weights.

        Returns
        -------
        data: np.ndarray
            The data.

        Raises
        ------
        Exception
            If ``name`` is not among the weights.
        """

        if name not in self._weights:
            raise Exception("Can not find data {} from {} weights".format(name, len(self._weights)))
        return msc_utils.cast_array(self._weights[name])

    def _save_tensor_cache(self, name: str, consumer: str, key: str, value: Any) -> Any:
        """Save the data to tensor cache

        Parameters
        -------
        name: str
            The tensor name.
        consumer: str
            The name of the consumer.
        key: str
            The data key.
        value: any
            The value to cache.

        Returns
        -------
        value: any
            The saved value.
        """

        tensor_id = self.to_tensor_id(name, consumer)
        if tensor_id not in self._tensor_cache:
            self._tensor_cache[tensor_id] = {}
        self._tensor_cache[tensor_id][key] = value
        return value

    def _get_tensor_cache(self, name: str, consumer: str, key: str) -> Any:
        """Get the cached tensor data

        Parameters
        -------
        name: str
            The tensor name.
        consumer: str
            The name of the consumer.
        key: str
            The data key.

        Returns
        -------
        value: any
            The cached value.
        """

        tensor_id = self.to_tensor_id(name, consumer)
        if tensor_id not in self._tensor_cache:
            return None
        return self._tensor_cache[tensor_id].get(key)

    def _get_tensor_strategys(self, name: str, consumer: str) -> List[ToolStrategy]:
        """Get the strategys by name and consumer

        Only strategies whose ``support_stage`` accepts the current stage are used.
        Lookup order:

        1. A strategy registered under the tensor id, or else under the tensor name.
        2. For weights, the first match among ``<ref>.<w_type>``, trying
           w_type in (consumer's weight type, "weights") and, for each,
           ref in (consumer name, consumer optype, "default").
        3. For tensors consumed by "exit", the first match among
           ``<producer name|producer optype|exit|default>.output``.
        4. Otherwise, up to two strategies: the first match among
           ``<producer name|producer optype|default>.output`` and the first match
           among ``<consumer name|consumer optype|default>.input``.

        The result is cached per stage under the key ``strategy.<stage>``.

        Parameters
        ----------
        name: str
            The tensor name.
        consumer: str
            The name of the consumer.

        Returns
        -------
        strategys: list<ToolStrategy>
            The strategys for the tensor.
        """

        tensor_id = self.to_tensor_id(name, consumer)
        mark = "strategy.{}".format(self._stage)
        if mark not in self._tensor_cache.get(tensor_id, {}):
            strategys = []

            def _add_strategy(ref):
                # Use the strategy registered under `ref` only if it supports this stage.
                if ref in self._strategys and self._strategys[ref].support_stage(self._stage):
                    strategys.append(self._strategys[ref])
                    return True
                return False

            def _add_first(refs):
                # Add only the first applicable strategy among `refs`, in priority order.
                for ref in refs:
                    if _add_strategy(ref):
                        return True
                return False

            tensor_strategy = self._strategys.get(tensor_id) or self._strategys.get(name)
            if tensor_strategy and tensor_strategy.support_stage(self._stage):
                strategys.append(tensor_strategy)
            elif self.is_weight(name):
                # `consumer` must stay the name string: it builds the cache key below,
                # so the resolved node gets its own variable.
                consumer_node = self.find_node(consumer)
                w_types = [consumer_node.weight_type(name), "weights"]
                c_refs = [consumer_node.name, consumer_node.optype, "default"]
                _add_first(ref + "." + w_type for w_type in w_types for ref in c_refs)
            elif consumer == "exit":
                producer = self.find_producer(name)
                p_refs = [producer.name, producer.optype, "exit", "default"]
                _add_first(ref + ".output" for ref in p_refs)
            else:
                producer = self.find_producer(name)
                p_refs = [producer.name, producer.optype, "default"]
                _add_first(ref + ".output" for ref in p_refs)
                consumer_node = self.find_node(consumer)
                c_refs = [consumer_node.name, consumer_node.optype, "default"]
                _add_first(ref + ".input" for ref in c_refs)
            self._save_tensor_cache(name, consumer, mark, strategys)
        return self._get_tensor_cache(name, consumer, mark)

    def _get_tensor_strategy(self, name: str, consumer: str) -> ToolStrategy:
        """Get the single strategy that applies to a tensor

        Parameters
        ----------
        name: str
            The tensor name.
        consumer: str
            The name of the consumer.

        Returns
        -------
        strategy: ToolStrategy or None
            The only strategy for the tensor, or None when no strategy applies.

        Raises
        ------
        AssertionError
            If more than one strategy applies.
        """

        candidates = self._get_tensor_strategys(name, consumer)
        if candidates:
            assert len(candidates) == 1, "{} should only has 1 strategy, get {}".format(
                self._stage, candidates
            )
            return candidates[0]
        return None

    def get_graph(self):
        """Return the graph selected by the current graph id."""

        current_id = self._graph_id
        return self._graphs[current_id]

    @property
    def plan(self):
        """The tool's plan, read from ``self._plan``."""

        current_plan = self._plan
        return current_plan

    @classmethod
    def tool_type(cls):
        """The tool type of this class (``ToolType.BASE`` for the base tool)."""

        return ToolType.BASE

    @classmethod
    def framework(cls):
        """The framework this tool works on (``MSCFramework.MSC`` for the base tool)."""

        return MSCFramework.MSC

    @classmethod
    def tool_style(cls):
        """The style name of this tool implementation."""

        style = "base"
        return style

    @classmethod
    def apply_once(cls):
        """Whether the tool should be applied only once; the base tool says no."""

        once = False
        return once

    @classmethod
    def exportable(cls):
        """Whether the tool can be exported; the base tool says yes."""

        can_export = True
        return can_export


class WeightTool(BaseTool):
    """Basic tool with weight graphs

    Besides the MSCGraphs handled by ``BaseTool``, this tool keeps one
    ``WeightGraph`` per MSCGraph. The graphs are built from the main/relation
    weight types returned by ``_get_wtypes``, which subclasses must implement.
    """

    def setup(self) -> dict:
        """Setup the tool

        Clears the weight graphs, then runs the base setup.

        Returns
        -------
        info: dict
            The setup info.
        """

        self._weight_graphs = []
        return super().setup()

    def _reset(
        self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor]
    ) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]:
        """Reset the tool

        Weight graphs are built from ``graphs`` only when none exist yet.
        Existing weight graphs (e.g. from ``load_cache``) are kept and must match
        ``graphs`` in number.

        Parameters
        ----------
        graphs: list<MSCgraph>
            The msc graphs.
        weights: dict<str, tvm.runtime.tensor>
            The weights.

        Returns
        -------
        graphs: list<MSCgraph>
            The msc graphs.
        weights: dict<str, tvm.runtime.tensor>
            The weights.
        """

        graphs, weights = super()._reset(graphs, weights)
        self._main_wtypes, self._relation_wtypes = self._get_wtypes()
        assert self._main_wtypes, "main_wtypes should be given to build weight graphs"
        if self._weight_graphs:
            # Reuse the existing weight graphs; they must line up one-to-one with graphs.
            assert len(graphs) == len(
                self._weight_graphs
            ), "Graphs {} mismatch with weight graphs {}".format(
                len(graphs), len(self._weight_graphs)
            )
        else:
            self._weight_graphs = [
                _ffi_api.WeightGraph(graph, self._main_wtypes, self._relation_wtypes)
                for graph in graphs
            ]
            msg = "build {} weight graphs".format(len(self._weight_graphs))
            self._logger.debug(self.tool_mark(msg))
        if self.on_debug(2, in_forward=False):
            weight_graphs = {g.name: g.inspect() for g in self._weight_graphs}
            title = self.tool_mark("WEIGHT_GRAPHS({})".format(len(weight_graphs)))
            self._logger.debug(msc_utils.msg_block(title, weight_graphs))
        return graphs, weights

    def _get_wtypes(self) -> Tuple[Dict[str, List[str]], Dict[str, str]]:
        """Get the weight types from options

        Must be implemented by subclasses.

        Returns
        -------
        main_wtypes: dict<str,list<str>>
            The main weight types.
        relation_wtypes: dict<str, str>
            The relation weight types

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """

        raise NotImplementedError("_get_wtypes is not implemented in WeightTool")

    def load_cache(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict):
        """Load the weight graphs from cache

        Parameters
        ----------
        cache_dir: MSCDirectory
            cache path for save/load info
        cache_info: dict
            The cache_info; its "weight_graphs" entry lists json files relative
            to ``cache_dir``.

        Raises
        ------
        AssertionError
            If ``cache_info`` has no "weight_graphs" entry.
        """

        assert (
            "weight_graphs" in cache_info
        ), "weight_graphs should be given in cache_info, get " + str(cache_info)
        self._weight_graphs = [
            WeightGraph.from_json(cache_dir.relpath(f)) for f in cache_info["weight_graphs"]
        ]
        msg = "load {} weight graphs from {}".format(len(self._weight_graphs), cache_dir)
        self._logger.debug(self.tool_mark(msg))

    def save_cache(self, cache_dir: msc_utils.MSCDirectory) -> dict:
        """Save the weight graphs to cache

        Each weight graph is written as ``<graph name>_graph.json``. The writes happen
        inside ``with cache_dir:``, and the file names are recorded for ``load_cache``.

        Parameters
        ----------
        cache_dir: MSCDirectory
            cache path for save/load info

        Returns
        -------
        cache_info: dict
            The cache_info.
        """

        cache_info = {"weight_graphs": [g.name + "_graph.json" for g in self._weight_graphs]}
        with cache_dir:
            for graph, f_path in zip(self._weight_graphs, cache_info["weight_graphs"]):
                # JSON is written explicitly as UTF-8 instead of the locale default.
                with open(f_path, "w", encoding="utf-8") as f_graph:
                    f_graph.write(graph.to_json())
        return cache_info

    def visualize(self, visual_dir: msc_utils.MSCDirectory):
        """Visualize the weight graphs, then the MSCGraphs

        Parameters
        ----------
        visual_dir: MSCDirectory
            Visualize path for saving graph
        """

        for w_graph in self._weight_graphs:
            w_graph.visualize(visual_dir.relpath(w_graph.name + ".prototxt"))
        super().visualize(visual_dir)

    def get_w_nodes(self) -> Iterable[WeightJoint]:
        """Get all the weight nodes in the weight_graphs.

        Returns
        -------
        nodes: generator<WeightJoint>
            The generator of weight nodes.
        """

        for g in self._weight_graphs:
            yield from g.get_nodes()

    def has_w_node(self, name: str) -> bool:
        """Check if name in weight_graphs.

        Parameters
        ----------
        name: string
            The name of the node.

        Returns
        -------
        has_node: bool
            Whether node in weight_graphs.
        """

        return any(g.has_node(name) for g in self._weight_graphs)

    def find_w_node(self, name: str) -> WeightJoint:
        """Find weight node by name.

        Parameters
        ----------
        name: string
            The name of the node.

        Returns
        -------
        node: WeightJoint
            The found node.

        Raises
        ------
        Exception
            If no weight graph has a node with this name.
        """

        for g in self._weight_graphs:
            if g.has_node(name):
                return g.find_node(name)
        raise Exception("Can not find node {} from graphs".format(name))

    def _get_io_axes(self, w_node: WeightJoint) -> Tuple[int, int]:
        """Get the input output axes

        Resolution order: 1-D weights use axis 0 for both; explicit "in_axis" /
        "out_axis" attributes; the "I"/"O" layout axes; for 2-D weights with an
        "N" layout axis, the other axis for both; the "C" layout axis for both.

        Parameters
        ----------
        w_node: WeightJoint
            The weight node.

        Returns
        -------
        axes: (int, int)
            The input output axis.

        Raises
        ------
        Exception
            If none of the rules above applies.
        """

        weight = w_node.weight
        if weight.ndim == 1:
            return 0, 0
        if w_node.has_attr("in_axis") and w_node.has_attr("out_axis"):
            return int(w_node.get_attr("in_axis")), int(w_node.get_attr("out_axis"))
        in_axis, out_axis = weight.layout_of("I"), weight.layout_of("O")
        if in_axis >= 0 and out_axis >= 0:
            return in_axis, out_axis
        # `1 - layout_of("N")` is only a valid axis when "N" is really in the layout,
        # so test its presence rather than a dimension size.
        if weight.ndim == 2 and weight.layout_of("N") >= 0:
            io_axis = 1 - weight.layout_of("N")
            return io_axis, io_axis
        c_axis = weight.layout_of("C")
        if c_axis >= 0:
            return c_axis, c_axis
        raise Exception("Can not infer in_axis/out_axis from " + str(w_node))

    @classmethod
    def tool_type(cls):
        return ToolType.WEIGHT
